{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import functools\n",
    "import jax\n",
    "from jax import numpy as jnp, random, lax\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from flax import linen as nn, struct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from flax.core import Scope, init, apply, Array, lift, unfreeze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab_type": "code",
    "outputId": "2558605e-e485-407e-b062-74d31cc49f1e",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FrozenDict({'params': {'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n",
      "             [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}})\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[0.17045607]], dtype=float32),\n",
       " FrozenDict({'params': {'hidden': {'bias': DeviceArray([0., 0., 0.], dtype=float32), 'kernel': DeviceArray([[-0.22119394,  0.22075175, -0.0925657 ],\n",
       "              [ 0.40571952,  0.27750877,  1.0542233 ]], dtype=float32)}, 'out': {'kernel': DeviceArray([[ 0.21448377],\n",
       "              [-0.01530595],\n",
       "              [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)}}}))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def dense(\n",
    "    scope: Scope,\n",
    "    inputs: Array,\n",
    "    features: int,\n",
    "    bias: bool = True,\n",
    "    kernel_init=nn.linear.default_kernel_init,\n",
    "    bias_init=nn.initializers.zeros_init(),\n",
    "):\n",
    "  \"\"\"Applies a linear transformation to the last axis of `inputs`.\n",
    "\n",
    "  Args:\n",
    "    scope: scope in which the 'kernel' and optional 'bias' params are declared.\n",
    "    inputs: array whose last axis is contracted with the kernel.\n",
    "    features: number of output features.\n",
    "    bias: whether to add a bias of shape `(features,)`.\n",
    "    kernel_init: initializer for the `(inputs.shape[-1], features)` kernel.\n",
    "    bias_init: initializer for the bias.\n",
    "\n",
    "  Returns:\n",
    "    `jnp.dot(inputs, kernel)`, plus the bias when `bias` is true.\n",
    "  \"\"\"\n",
    "  in_features = inputs.shape[-1]\n",
    "  weights = scope.param('kernel', kernel_init, (in_features, features))\n",
    "  outputs = jnp.dot(inputs, weights)\n",
    "  if not bias:\n",
    "    return outputs\n",
    "  return outputs + scope.param('bias', bias_init, (features,))\n",
    "\n",
    "\n",
    "x = jnp.ones((1, 2))\n",
    "\n",
    "# Bind `features` up front so the initializer only needs a key and the inputs.\n",
    "model_fn = functools.partial(dense, features=3)\n",
    "y, params = init(model_fn)(random.key(0), x)\n",
    "print(params)\n",
    "\n",
    "\n",
    "def mlp(scope: Scope, inputs: Array, features: int):\n",
    "  \"\"\"Two-layer perceptron: dense(features) -> relu -> dense(1).\n",
    "\n",
    "  The 'hidden' layer is nested with `scope.child`, the 'out' layer with an\n",
    "  explicit `scope.push`; both calls appear here side by side.\n",
    "  \"\"\"\n",
    "  hidden_layer = scope.child(dense, 'hidden')\n",
    "  activations = nn.relu(hidden_layer(inputs, features))\n",
    "  out_scope = scope.push('out')\n",
    "  return dense(out_scope, activations, 1)\n",
    "\n",
    "\n",
    "init(mlp)(random.key(0), x, features=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab_type": "code",
    "outputId": "5790b763-df4f-47c8-9f4e-53fd1e1eb1fd",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.11575121 -0.51936364 -1.113899  ]\n",
      " [ 0.45569834 -0.5300623  -0.5873911 ]]\n",
      "[ 0.45569834 -0.5300623  -0.5873911 ]\n",
      "[[-1.5175114 -0.6617551]]\n"
     ]
    }
   ],
   "source": [
    "@struct.dataclass\n",
    "class Embedding:\n",
    "  \"\"\"Holds an embedding table and the two operations defined on it.\"\"\"\n",
    "\n",
    "  table: np.ndarray\n",
    "\n",
    "  def lookup(self, indices):\n",
    "    \"\"\"Indexes `table` with `indices` along its first axis.\"\"\"\n",
    "    rows = self.table[indices]\n",
    "    return rows\n",
    "\n",
    "  def attend(self, query):\n",
    "    \"\"\"Returns `jnp.dot(query, table.T)`: `query` scored against each row.\"\"\"\n",
    "    transposed_table = self.table.T\n",
    "    return jnp.dot(query, transposed_table)\n",
    "\n",
    "\n",
    "# all the embedding module does is provide a convenient initializer\n",
    "\n",
    "\n",
    "def embedding(\n",
    "    scope: Scope,\n",
    "    num_embeddings: int,\n",
    "    features: int,\n",
    "    init_fn=nn.linear.default_embed_init,\n",
    ") -> Embedding:\n",
    "  \"\"\"Declares a 'table' param and wraps it in an `Embedding`.\n",
    "\n",
    "  Args:\n",
    "    scope: scope in which the 'table' param is declared.\n",
    "    num_embeddings: number of rows in the table.\n",
    "    features: number of columns in the table.\n",
    "    init_fn: initializer for the `(num_embeddings, features)` table.\n",
    "\n",
    "  Returns:\n",
    "    An `Embedding` holding the table.\n",
    "  \"\"\"\n",
    "  table_shape = (num_embeddings, features)\n",
    "  return Embedding(scope.param('table', init_fn, table_shape))\n",
    "\n",
    "\n",
    "# Bind the result to a new name: reusing `embedding` would shadow the\n",
    "# function above and make this cell fail when it is run a second time.\n",
    "embed, _ = init(embedding)(random.key(0), num_embeddings=2, features=3)\n",
    "print(embed.table)\n",
    "print(embed.lookup(1))\n",
    "print(embed.attend(jnp.ones((1, 3))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab_type": "code",
    "outputId": "dd9c5079-10e7-4944-e09a-e9f65573a733"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((((1, 3), (1, 3)), (1, 3)),\n",
       " FrozenDict({'params': {'hf': {'bias': (3,), 'kernel': (3, 3)}, 'hg': {'bias': (3,), 'kernel': (3, 3)}, 'hi': {'bias': (3,), 'kernel': (3, 3)}, 'ho': {'bias': (3,), 'kernel': (3, 3)}, 'if': {'kernel': (2, 3)}, 'ig': {'kernel': (2, 3)}, 'ii': {'kernel': (2, 3)}, 'io': {'kernel': (2, 3)}}}))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def lstm(\n",
    "    scope,\n",
    "    carry,\n",
    "    inputs,\n",
    "    gate_fn=nn.activation.sigmoid,\n",
    "    activation_fn=nn.activation.tanh,\n",
    "    kernel_init=nn.linear.default_kernel_init,\n",
    "    recurrent_kernel_init=nn.initializers.orthogonal(),\n",
    "    bias_init=nn.initializers.zeros_init(),\n",
    "):\n",
    "  r\"\"\"A long short-term memory (LSTM) cell.\n",
    "\n",
    "  The mathematical definition of the cell is as follows:\n",
    "\n",
    "  .. math::\n",
    "\n",
    "      \\begin{array}{ll}\n",
    "      i = \\sigma(W_{ii} x + W_{hi} h + b_{hi}) \\\\\n",
    "      f = \\sigma(W_{if} x + W_{hf} h + b_{hf}) \\\\\n",
    "      g = \\tanh(W_{ig} x + W_{hg} h + b_{hg}) \\\\\n",
    "      o = \\sigma(W_{io} x + W_{ho} h + b_{ho}) \\\\\n",
    "      c' = f * c + i * g \\\\\n",
    "      h' = o * \\tanh(c') \\\\\n",
    "      \\end{array}\n",
    "\n",
    "  where x is the input, h is the output of the previous time step, and c is\n",
    "  the memory.\n",
    "\n",
    "  Args:\n",
    "    scope: scope in which the eight gate projections are declared as\n",
    "      children named 'ii', 'if', 'ig', 'io', 'hi', 'hf', 'hg' and 'ho'.\n",
    "    carry: the `(c, h)` state of the LSTM cell, e.g. as created by\n",
    "      `lstm_init_carry`.\n",
    "    inputs: an ndarray with the input for the current time step.\n",
    "      All dimensions except the final are considered batch dimensions.\n",
    "    gate_fn: activation function used for gates (default: sigmoid).\n",
    "    activation_fn: activation function used for output and memory update\n",
    "      (default: tanh).\n",
    "    kernel_init: initializer function for the kernels that transform\n",
    "      the input (default: lecun_normal).\n",
    "    recurrent_kernel_init: initializer function for the kernels that transform\n",
    "      the hidden state (default: orthogonal).\n",
    "    bias_init: initializer for the bias parameters (default: zeros_init()).\n",
    "\n",
    "  Returns:\n",
    "    A tuple `((new_c, new_h), new_h)` with the new carry and the output.\n",
    "  \"\"\"\n",
    "  c, h = carry\n",
    "  hidden_features = h.shape[-1]\n",
    "\n",
    "  # Input and recurrent projections are summed, so only one of them needs a\n",
    "  # bias; it lives on the recurrent (hidden-state) projection.\n",
    "  def dense_h(name):\n",
    "    return scope.child(dense, name)(\n",
    "        h,\n",
    "        features=hidden_features,\n",
    "        bias=True,\n",
    "        kernel_init=recurrent_kernel_init,\n",
    "        bias_init=bias_init,\n",
    "    )\n",
    "\n",
    "  def dense_i(name):\n",
    "    return scope.child(dense, name)(\n",
    "        inputs, features=hidden_features, bias=False, kernel_init=kernel_init\n",
    "    )\n",
    "\n",
    "  i = gate_fn(dense_i(name='ii') + dense_h(name='hi'))\n",
    "  f = gate_fn(dense_i(name='if') + dense_h(name='hf'))\n",
    "  g = activation_fn(dense_i(name='ig') + dense_h(name='hg'))\n",
    "  o = gate_fn(dense_i(name='io') + dense_h(name='ho'))\n",
    "  new_c = f * c + i * g\n",
    "  new_h = o * activation_fn(new_c)\n",
    "  return (new_c, new_h), new_h\n",
    "\n",
    "\n",
    "def lstm_init_carry(batch_dims, size, init_fn=jnp.zeros):\n",
    "  \"\"\"Creates the initial `(c, h)` carry for `lstm`.\n",
    "\n",
    "  Args:\n",
    "    batch_dims: iterable with the leading (batch) dimensions, e.g. `(1,)`.\n",
    "    size: size of the last axis of both arrays.\n",
    "    init_fn: called with the full shape to create each array\n",
    "      (default: `jnp.zeros`).\n",
    "\n",
    "  Returns:\n",
    "    A pair of arrays, each of shape `(*batch_dims, size)`.\n",
    "  \"\"\"\n",
    "  # Unpacking accepts any iterable (tuple, list, ...), not only tuples.\n",
    "  shape = (*batch_dims, size)\n",
    "  return init_fn(shape), init_fn(shape)\n",
    "\n",
    "\n",
    "# Initial carry for batch size 1 and 3 hidden units; inputs have 2 features.\n",
    "carry = lstm_init_carry((1,), 3)\n",
    "x = jnp.ones((1, 2))\n",
    "y, variables = init(lstm)(random.key(0), carry, x)\n",
    "\n",
    "# Show array shapes rather than values.\n",
    "jax.tree_util.tree_map(np.shape, (y, variables))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initialized parameter shapes:\n",
      " {'params': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}\n"
     ]
    }
   ],
   "source": [
    "def simple_scan(scope: Scope, xs):\n",
    "  \"\"\"Runs `lstm` over axis 1 of `xs` with `lift.scan`.\n",
    "\n",
    "  Conceptually this replaces a Python loop of the form::\n",
    "\n",
    "    cell = scope.child(lstm, 'cell')\n",
    "    ys = []\n",
    "    for i in range(xs.shape[1]):\n",
    "      x = xs[:, i]\n",
    "      init_carry, y = cell(init_carry, x)\n",
    "      ys.append(y)\n",
    "    return init_carry, ys\n",
    "\n",
    "  Args:\n",
    "    scope: scope in which the LSTM params (broadcast across steps) are declared.\n",
    "    xs: input sequence; axis 0 gives the batch dims of the carry, axis 1 is\n",
    "      scanned over, and the last axis also sets the carry size.\n",
    "\n",
    "  Returns:\n",
    "    What the scanned `lstm` returns for `(scope, init_carry, xs)`.\n",
    "  \"\"\"\n",
    "  scanned_lstm = lift.scan(\n",
    "      lstm,\n",
    "      in_axes=1,\n",
    "      out_axes=1,\n",
    "      variable_broadcast='params',\n",
    "      split_rngs={'params': False},\n",
    "  )\n",
    "  batch_dims, hidden_size = xs.shape[:1], xs.shape[-1]\n",
    "  zero_carry = lstm_init_carry(batch_dims, hidden_size)\n",
    "  return scanned_lstm(scope, zero_carry, xs)\n",
    "\n",
    "\n",
    "# Independent keys: one to draw the inputs, one to initialize the parameters.\n",
    "key1, key2 = random.split(random.key(0))\n",
    "\n",
    "# Axis 0 is the batch, axis 1 the scanned axis, axis 2 the features.\n",
    "xs = random.uniform(key1, (1, 5, 2))\n",
    "\n",
    "y, init_variables = init(simple_scan)(key2, xs)\n",
    "\n",
    "print(\n",
    "    'initialized parameter shapes:\\n',\n",
    "    jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output:\n",
      " (DeviceArray([[-0.35626447,  0.25178757]], dtype=float32), DeviceArray([[-0.17885922,  0.13063088]], dtype=float32))\n"
     ]
    }
   ],
   "source": [
    "# `simple_scan` returns `(final_carry, outputs)`; keep only the final carry.\n",
    "y, _ = apply(simple_scan)(init_variables, xs)\n",
    "print('output:\\n', y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 0
}
